package model

import (
	"database/sql"
	"fmt"

	_ "github.com/go-sql-driver/mysql"
	"github.com/spf13/viper"

	log "github.com/sirupsen/logrus"
)

var db *sql.DB

func InitDB() {
	addr := viper.GetString("mysql.addr")
	port := viper.GetString("mysql.port")
	database := viper.GetString("mysql.database")
	username := viper.GetString("mysql.username")
	password := viper.GetString("mysql.password")

	log.Info(fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", username, password, addr, port, database))

	var err error
	db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", username, password, addr, port, database))
	if err != nil {
		log.Fatal(err)
	}

	err = db.Ping()
	if err != nil {
		log.Fatal(err)
	}
}

func GetData(sqlstr string) (data [][]string) {
	rows, err := db.Query(sqlstr)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	// 获取列信息
	columns, err := rows.Columns()
	if err != nil {
		log.Fatal(err)
	}

	data = append(data, columns)

	// 为每个字段创建一个指针...
	values := make([]sql.RawBytes, len(columns))
	// ...用来缓存每个字段的值
	scanArgs := make([]interface{}, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}

	// 遍历每一行
	for rows.Next() {
		// 根据指定的列，通过指针扫描行内容
		err = rows.Scan(scanArgs...)
		if err != nil {
			log.Fatal(err)
		}
		// 根据列类型格式化每个字段的值
		da := []string{}
		var v string
		for _, col := range values {
			if col == nil {
				v = "NULL"
			} else {
				v = string(col)
			}
			da = append(da, v)
		}
		data = append(data, da)
	}

	// 检查遍历是否出现错误
	if err = rows.Err(); err != nil {
		log.Fatal(err)
	}

	return
}

// CloseDB 关闭数据库连接
func CloseDB() {
	if db != nil {
		db.Close()
		db = nil
	}
}
